import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
import matplotlib.pyplot as plt
from gymnasium.wrappers import NormalizeObservation
import wandb
import torch.multiprocessing as mp
from gymnasium.vector import SyncVectorEnv
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import TRPO  # TRPO is in sb3_contrib, not in the main stable_baselines3 package

# Compute device for every network and sampled batch in this file.
# GPU index 3 is preferred; hosts that expose fewer GPUs fall back to GPU 0
# (instead of producing an invalid "cuda:3" device), and hosts without CUDA
# fall back to the CPU.
_PREFERRED_CUDA_INDEX = 3
if torch.cuda.is_available():
    _cuda_index = _PREFERRED_CUDA_INDEX if torch.cuda.device_count() > _PREFERRED_CUDA_INDEX else 0
    device = torch.device(f"cuda:{_cuda_index}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Seed the three RNGs used in this file (torch, numpy, stdlib random) for reproducibility
torch.manual_seed(3)
np.random.seed(3)
random.seed(3)

class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Backed by ``deque(maxlen=capacity)``: once ``capacity`` transitions are
    held, every new push evicts the oldest one.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition; the five components are stored exactly as given."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple ``(state, action, reward, next_state, done)`` of float32
            tensors on the module-level ``device``, batch along dim 0.

        Raises:
            ValueError: Raised by ``random.sample`` when fewer than
                ``batch_size`` transitions are stored.
        """
        columns = zip(*random.sample(self.buffer, batch_size))
        # Stack every column into a single float32 ndarray before handing it to
        # torch: building a tensor directly from a sequence of ndarrays is
        # converted element by element and is far slower.
        return tuple(
            torch.as_tensor(np.asarray(column, dtype=np.float32), device=device)
            for column in columns
        )

    def get_all_states_and_actions(self):
        """Return ``(states, actions)`` as two lists, oldest transition first."""
        states = [entry[0] for entry in self.buffer]
        actions = [entry[1] for entry in self.buffer]
        return states, actions

    def free(self):
        """Drop every stored transition."""
        self.buffer.clear()

    def __len__(self):
        return len(self.buffer)

class QNetwork(nn.Module):
    """Critic MLP mapping a (state, action) pair to a single scalar.

    The state and action are concatenated along dim 1, passed through two
    ReLU hidden layers of width ``hidden_dim`` and projected to one output.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        joint_dim = state_dim + action_dim
        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return the critic output for each row of ``state``/``action``."""
        hidden = torch.cat([state, action], dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

class PolicyNetwork(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    A two-layer ReLU trunk feeds two linear heads: one for the mean and one
    for the log standard deviation (clamped to
    ``[log_std_min, log_std_max]``) of a Normal over pre-squash actions.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for ``state``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(state))))
        mean = self.mean(hidden)
        bounded_log_std = torch.clamp(self.log_std(hidden), self.log_std_min, self.log_std_max)
        return mean, bounded_log_std

    def _squashed_log_prob(self, gaussian, pre_tanh, action):
        """Log-density of ``action = tanh(pre_tanh)``, summed over dim 1 (kept as a column)."""
        # Change of variables for tanh; the 1e-6 keeps the log finite as |action| -> 1.
        tanh_correction = torch.log(1 - action.pow(2) + 1e-6)
        return (gaussian.log_prob(pre_tanh) - tanh_correction).sum(1, keepdim=True)

    def sample(self, state):
        """Draw a reparameterised action in [-1, 1] and return it with its log-probability."""
        mean, log_std = self.forward(state)
        gaussian = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = gaussian.rsample()
        action = torch.tanh(pre_tanh)
        return action, self._squashed_log_prob(gaussian, pre_tanh, action)

    def get_log_prob(self, state, action):
        """Return the log-probability of an already-squashed ``action`` under the policy at ``state``."""
        mean, log_std = self.forward(state)
        gaussian = torch.distributions.Normal(mean, log_std.exp())

        # Invert the tanh squash; the clamp keeps atanh finite at the boundary.
        pre_tanh = torch.atanh(torch.clamp(action, -0.999, 0.999))
        return self._squashed_log_prob(gaussian, pre_tanh, action)

class SAC_DPO:
    """SAC-style agent whose policy is trained with either the SAC loss or a pairwise "DFA" loss.

    Owns a ``PolicyNetwork``, two ``QNetwork`` critics and a target copy of
    each critic, all placed on the module-level ``device``.

    Args:
        state_dim: Size of one observation vector.
        action_dim: Size of one action vector.
        hidden_dim: Hidden-layer width of every network.
        policy_lr: Adam learning rate of the policy.
        Q_lr: Adam learning rate of both critics.
        beta: Scale applied to the log-probability gap in ``dfa_loss``.
        alpha: Entropy temperature used in the critic target and the SAC policy loss.
        gamma: Discount factor.
        tau: Interpolation weight of the soft target-network update.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, policy_lr=3e-4, Q_lr=1e-3, beta=0.5, alpha=0.01, gamma=0.99, tau=0.001):
        self.policy = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.q1 = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.q2 = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_q1 = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_q2 = QNetwork(state_dim, action_dim, hidden_dim).to(device)

        # Targets start as exact copies of the online critics.
        self.target_q1.load_state_dict(self.q1.state_dict())
        self.target_q2.load_state_dict(self.q2.state_dict())

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=policy_lr)
        self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=Q_lr)
        self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=Q_lr)

        self.beta = beta
        self.alpha = alpha  # Temperature parameter for SAC
        self.gamma = gamma
        self.tau = tau  # Soft update parameter
        self.state_dim = state_dim
        self.action_dim = action_dim  # Needed to size the random fallback actions

    def get_q_value(self, state, action):
        """Return the element-wise minimum of the two *target* critics."""
        q1 = self.target_q1(state, action)
        q2 = self.target_q2(state, action)
        return torch.min(q1, q2)

    def find_closest_state_action(self, current_state, replay_buffer):
        """Look up the buffered state nearest (Euclidean) to ``current_state``.

        Args:
            current_state: One observation vector.
            replay_buffer: Object supporting ``len()`` and
                ``get_all_states_and_actions()``.

        Returns:
            tuple: ``(closest_action, mean_state)`` where ``closest_action`` is
            the action stored with the nearest state and ``mean_state`` is the
            midpoint of ``current_state`` and that state. With an empty buffer,
            a uniform random action in [-1, 1] of length ``action_dim`` and
            ``current_state`` itself are returned.
        """
        if len(replay_buffer) == 0:
            # The fallback is an *action*, so it must be sized by action_dim.
            return np.random.uniform(-1, 1, size=self.action_dim), current_state

        states, actions = replay_buffer.get_all_states_and_actions()

        # Stack with numpy first; building tensors from lists of ndarrays is slow.
        current_state_tensor = torch.as_tensor(np.asarray(current_state, dtype=np.float32), device=device)
        states_tensor = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)

        distances = torch.norm(states_tensor - current_state_tensor.unsqueeze(0), dim=1)
        closest_index = torch.argmin(distances).item()
        closest_state = states[closest_index]

        mean_state = np.mean([current_state, closest_state], axis=0)
        return actions[closest_index], mean_state

    def select_action(self, state):
        """Sample one action for a single observation and return it as a numpy array."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            action, _ = self.policy.sample(state_tensor)
        return action.squeeze(0).cpu().numpy()

    def sample_action_pairs_for_batch(self, states, actions1, replay_buffer):
        """Build a (chosen, rejected) action pair for every state in the batch.

        The first candidate is a fresh policy sample at the state. The second
        is the action stored with the nearest *non-identical* state in the
        replay buffer (or a uniform random action when the buffer is empty).
        Whichever candidate has the higher target Q-value at the state becomes
        "chosen", the other "rejected".

        Args:
            states: Batch of states, one per row.
            actions1: Ignored; kept for interface compatibility. The first
                candidate is always re-sampled from the policy.
            replay_buffer: Source of the nearest-neighbour candidates.

        Returns:
            ``(chosen_actions, chosen_log_probs, rejected_actions,
            rejected_log_probs, mean_states)``. ``mean_states`` is the midpoint
            between each state and its nearest neighbour (or a copy of
            ``states`` when the buffer is empty).
        """
        batch_size = states.size(0)

        actions1, log_probs1 = self.policy.sample(states)

        buffer_states, buffer_actions = replay_buffer.get_all_states_and_actions()

        if len(buffer_states) == 0:
            # Random second candidates; these are actions, hence action_dim columns.
            actions2 = torch.as_tensor(
                np.random.uniform(-1, 1, size=(batch_size, self.action_dim)),
                dtype=torch.float32,
                device=device,
            )
            mean_states = states.clone()
        else:
            # Stack with numpy first; building tensors from lists of ndarrays is slow.
            buffer_states_tensor = torch.as_tensor(np.asarray(buffer_states, dtype=np.float32), device=device)
            buffer_actions_tensor = torch.as_tensor(np.asarray(buffer_actions, dtype=np.float32), device=device)

            # Pairwise distances via broadcasting:
            # [batch, 1, state_dim] - [1, buffer, state_dim] -> [batch, buffer]
            distances = torch.norm(states.unsqueeze(1) - buffer_states_tensor.unsqueeze(0), dim=2)

            # A batch state is usually in the buffer itself; push exact matches
            # far away so the nearest *different* state is selected.
            identical_mask = distances < 1e-10
            distances = distances + identical_mask.float() * 1e10

            closest_indices = torch.argmin(distances, dim=1)
            actions2 = buffer_actions_tensor[closest_indices]
            mean_states = (states + buffer_states_tensor[closest_indices]) / 2.0

        log_probs2 = self.policy.get_log_prob(states, actions2)

        # The Q-values only decide which candidate wins, so no graph is needed.
        with torch.no_grad():
            better_q1 = self.get_q_value(states, actions1) > self.get_q_value(states, actions2)

        # better_q1 is a [batch, 1] mask and broadcasts over the action columns.
        chosen_actions = torch.where(better_q1, actions1, actions2)
        rejected_actions = torch.where(better_q1, actions2, actions1)
        chosen_log_probs = torch.where(better_q1, log_probs1, log_probs2)
        rejected_log_probs = torch.where(better_q1, log_probs2, log_probs1)

        # If no pair in the whole batch has any log-probability gap, shift the
        # rejected side by a constant so the batch is not entirely degenerate.
        has_gap = (chosen_log_probs - rejected_log_probs).abs() > 0
        if not has_gap.any():
            rejected_log_probs = rejected_log_probs + 0.1

        return chosen_actions, chosen_log_probs, rejected_actions, rejected_log_probs, mean_states

    def update_q_network(self, q_network, optimizer, state, action, target_q):
        """Take one MSE gradient step of ``q_network`` towards ``target_q``; return the loss as a float."""
        q_value = q_network(state, action)
        loss = F.mse_loss(q_value, target_q)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def _update_critics(self, state, action, reward, next_state, done):
        """Regress both critics onto the shared soft Bellman target; return ``(q1_loss, q2_loss)``."""
        with torch.no_grad():
            next_action, next_log_prob = self.policy.sample(next_state)
            target_q1 = self.target_q1(next_state, next_action)
            target_q2 = self.target_q2(next_state, next_action)
            target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_prob
            target_q = reward + (1 - done) * self.gamma * target_q

        # Plain sequential calls: torch.jit.fork only runs asynchronously inside
        # TorchScript, so wrapping these eager Python calls in it added nothing.
        q1_loss = self.update_q_network(self.q1, self.q1_optimizer, state, action, target_q)
        q2_loss = self.update_q_network(self.q2, self.q2_optimizer, state, action, target_q)
        return q1_loss, q2_loss

    def SAC_update(self, replay_buffer, batch_size):
        """Run one SAC step (critics, policy, soft target update) on a sampled batch.

        Returns:
            ``(q1_loss, q2_loss, policy_loss)`` as floats, or ``(0, 0, 0)`` when
            the buffer holds fewer than ``batch_size`` transitions.
        """
        if len(replay_buffer) < batch_size:
            return 0, 0, 0

        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        q1_loss, q2_loss = self._update_critics(state, action, reward, next_state, done)

        # SAC policy objective: entropy-regularised Q of a fresh reparameterised action.
        new_action, log_prob = self.policy.sample(state)
        q_new = torch.min(self.q1(state, new_action), self.q2(state, new_action))
        policy_loss = (self.alpha * log_prob - q_new).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.soft_update_target_networks()

        return q1_loss, q2_loss, policy_loss.item()

    def soft_update_target_networks(self):
        """Move each target critic towards its online critic: ``target <- tau*online + (1-tau)*target``."""
        with torch.no_grad():
            for target_net, online_net in ((self.target_q1, self.q1), (self.target_q2, self.q2)):
                for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def dfa_update(self, replay_buffer, batch_size):
        """Run one DFA step: the SAC critic update followed by a ``dfa_loss`` policy update.

        Returns:
            ``(q1_loss, q2_loss, dfa_loss)`` as floats (the last one as an
            absolute value, for reporting), or ``(0, 0, 0)`` when the buffer
            holds fewer than ``batch_size`` transitions.
        """
        if len(replay_buffer) < batch_size:
            return 0, 0, 0

        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        q1_loss, q2_loss = self._update_critics(state, action, reward, next_state, done)

        chosen_actions, chosen_log_probs, rejected_actions, rejected_log_probs, mean_states = \
            self.sample_action_pairs_for_batch(state, action, replay_buffer)

        dfa_loss = self.dfa_loss(mean_states, chosen_actions, chosen_log_probs, rejected_actions, rejected_log_probs)

        # Reported as a magnitude only; the signed tensor is what gets optimised.
        dfa_loss_value = abs(dfa_loss.item())

        self.policy_optimizer.zero_grad()
        dfa_loss.backward()
        self.policy_optimizer.step()

        self.soft_update_target_networks()

        return q1_loss, q2_loss, dfa_loss_value

    def dfa_loss(self, states, chosen_actions, chosen_log_probs, rejected_actions, rejected_log_probs):
        """Pairwise preference loss on the chosen/rejected log-probabilities.

        Computes ``-mean(logsigmoid(beta * (chosen - rejected)))`` plus
        ``0.01 * mean(|chosen| + |rejected|)``. ``states``, ``chosen_actions``
        and ``rejected_actions`` are accepted for interface compatibility but
        do not enter the computation.
        """
        log_ratio = chosen_log_probs - rejected_log_probs

        # logsigmoid stays accurate for very negative inputs, where
        # log(sigmoid(x) + eps) saturates at log(eps) with a vanishing gradient.
        preference_loss = -F.logsigmoid(self.beta * log_ratio).mean()

        # Penalise large log-probability magnitudes on both sides.
        magnitude_penalty = torch.mean(torch.abs(chosen_log_probs) + torch.abs(rejected_log_probs)) * 0.01

        return preference_loss + magnitude_penalty

def train_trpo_sb3(env_name, gamma=0.99, learning_rate=1e-3, batch_size=256, n_steps=2048, max_kl=0.01):
    """Train an sb3_contrib TRPO agent on ``env_name`` and log finished episodes to wandb.

    Builds 32 seeded copies of the environment wrapped in
    ``NormalizeObservation``, trains for 10,000,000 timesteps, saves the model
    under a name derived from the hyperparameters and uploads it to wandb.

    Args:
        env_name: Environment id handed to ``make_vec_env``.
        gamma: Discount factor.
        learning_rate: Learning rate passed to ``TRPO``.
        batch_size: Batch size passed to ``TRPO``.
        n_steps: Rollout length per environment passed to ``TRPO``.
        max_kl: Passed to ``TRPO`` as ``target_kl``.

    Returns:
        The trained ``TRPO`` model.
    """
    run_name = f"As-paper-Env={env_name}_TRPO_SB3_gamma={gamma}_lr={learning_rate}_max_kl={max_kl}"

    run_config = {
        "algorithm": "TRPO_SB3",
        "env": env_name,
        "gamma": gamma,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "n_steps": n_steps,
        "max_kl": max_kl
    }
    wandb.init(project="sac-dfa-rl", name=run_name, config=run_config)

    vec_env = make_vec_env(env_name, n_envs=32, seed=3,
                          wrapper_class=NormalizeObservation)

    model = TRPO(
        "MlpPolicy",
        vec_env,
        gamma=gamma,
        learning_rate=learning_rate,
        batch_size=batch_size,
        n_steps=n_steps,
        target_kl=max_kl,
        verbose=1,
        tensorboard_log="./trpo_tensorboard/"
    )

    class WandbCallback:
        """Callable handed to ``model.learn``; logs an episode whenever ``infos[0]`` reports one."""

        def __init__(self):
            self.episode_count = 0
            self.total_timesteps = 0

        def __call__(self, locals, globals):
            # Always return True so training continues.
            if 'infos' not in locals or len(locals['infos']) == 0:
                return True

            first_info = locals['infos'][0]  # only the first entry is inspected
            if 'episode' not in first_info:
                return True

            episode_info = first_info['episode']
            self.episode_count += 1
            self.total_timesteps = locals['self'].num_timesteps

            wandb.log({
                "episode": self.episode_count,
                "episode_reward": episode_info['r'],
                "episode_length": episode_info['l'],
                "total_timesteps": self.total_timesteps
            }, step=self.total_timesteps)

            print(f"Episode {self.episode_count}: Reward = {episode_info['r']:.2f}, Length = {episode_info['l']}")
            return True

    model.learn(
        total_timesteps=10000000,
        callback=WandbCallback(),
        progress_bar=True
    )

    # Persist the model locally, then upload the resulting archive.
    model.save(f"trpo_sb3_{run_name}")
    wandb.save(f"trpo_sb3_{run_name}.zip")

    wandb.finish()

    return model

def _final_observations(infos):
    """Return ``(observations, mask)`` for sub-envs that ended this step, else ``(None, None)``.

    Vector envs that reset a sub-env in the same ``step`` call return the
    *reset* observation as the step result and report the real last
    observation through ``infos``.
    NOTE(review): the key is "final_observation" in Gymnasium 0.2x and
    "final_obs" in 1.x same-step mode - confirm against the installed version.
    """
    for key in ("final_observation", "final_obs"):
        mask_key = f"_{key}"
        if key in infos and mask_key in infos:
            return infos[key], infos[mask_key]
    return None, None


def train_sac_dfa(agent=None, dfa=False, beta=0.5, alpha=0.001, gamma=0.99, tau=0.001, env_name=None, policy_lr=None):
    """Train a ``SAC_DPO`` agent on 32 synchronous copies of ``env_name``, logging to wandb.

    Each outer "episode" resets all environments and steps them until every
    one has terminated or been truncated once; transitions of an environment
    are stored only up to its first episode end. One agent update is run per
    vector step once the buffer holds more than one batch.

    Args:
        agent: Optional pre-built agent to train. When ``None`` a new
            ``SAC_DPO`` is created from the remaining arguments.
        dfa: Use ``agent.dfa_update`` when true, ``agent.SAC_update`` otherwise.
        beta, alpha, gamma, tau: Forwarded to ``SAC_DPO`` for a new agent.
        env_name: Gymnasium environment id (required).
        policy_lr: Policy learning rate for a new agent; ``None`` keeps the
            ``SAC_DPO`` default.

    Returns:
        The trained agent.

    Raises:
        ValueError: If ``env_name`` is ``None``.
    """
    # Fail before a wandb run is started or any environment is built.
    if env_name is None:
        raise ValueError("env_name is required, e.g. env_name='Pendulum-v1'")

    run_name = f"Env={env_name}_{'DFA' if dfa else 'SAC'}_beta={beta}_alpha={alpha}_gamma={gamma}_tau={tau}_policy_lr={policy_lr}"
    batch_size = 256
    hidden_dim = 64

    wandb.init(
        project="sac-dfa-rl",
        name=run_name,
        config={
            "algorithm": "DFA" if dfa else "SAC",
            "env": env_name,
            "beta": beta,
            "alpha": alpha,
            "gamma": gamma,
            "tau": tau,
            "hidden_dim": hidden_dim,
            "batch_size": batch_size,
            # Stored as text: a torch.device object is not a plain config value.
            "device": str(device)
        }
    )

    num_envs = 32  # Adjust based on your CPU cores

    def make_env():
        def _thunk():
            return gym.make(env_name)
        return _thunk

    envs = SyncVectorEnv([make_env() for _ in range(num_envs)])

    state_dim = envs.single_observation_space.shape[0]
    action_dim = envs.single_action_space.shape[0]

    if agent is None:
        agent_kwargs = {"beta": beta, "alpha": alpha, "gamma": gamma, "tau": tau}
        # Forwarding policy_lr=None would hand Adam lr=None and crash.
        if policy_lr is not None:
            agent_kwargs["policy_lr"] = policy_lr
        agent = SAC_DPO(state_dim, action_dim, hidden_dim, **agent_kwargs)
    replay_buffer = ReplayBuffer(20000)

    episodes = 50000

    total_timesteps = 0
    update_step = 0

    wandb.watch(agent.policy, log='gradients', log_freq=5000)
    wandb.watch(agent.q1, log='gradients', log_freq=5000)

    update_fn = agent.dfa_update if dfa else agent.SAC_update
    policy_loss_key = "dfa_loss" if dfa else "policy_loss"

    try:
        for episode in range(episodes):
            states, _ = envs.reset()
            episode_rewards = np.zeros(num_envs)
            dones = np.zeros(num_envs, dtype=bool)
            truncateds = np.zeros(num_envs, dtype=bool)
            episode_steps = np.zeros(num_envs)

            while not np.all(dones | truncateds):
                # NOTE(review): the policy emits tanh actions in [-1, 1] and they are
                # sent to the env unscaled; envs with wider action bounds only see
                # part of their range - confirm this is intended.
                actions = np.array([agent.select_action(state) for state in states])
                next_states, rewards, dones_step, truncateds_step, infos = envs.step(actions)
                total_timesteps += num_envs

                final_obs, final_obs_mask = _final_observations(infos)

                for idx in range(num_envs):
                    # Ignore everything after an environment's first episode end.
                    if dones[idx] or truncateds[idx]:
                        continue

                    next_state = next_states[idx]
                    if final_obs is not None and final_obs_mask[idx]:
                        # Store the real last observation, not the post-reset one.
                        next_state = final_obs[idx]

                    replay_buffer.push(states[idx], actions[idx], [rewards[idx]], next_state, [dones_step[idx]])
                    episode_rewards[idx] += rewards[idx]
                    episode_steps[idx] += 1

                states = next_states
                dones |= dones_step
                truncateds |= truncateds_step

                if len(replay_buffer) > batch_size:
                    q1_loss, q2_loss, policy_loss = update_fn(replay_buffer, batch_size)
                    wandb.log({
                        "q1_loss": q1_loss,
                        "q2_loss": q2_loss,
                        policy_loss_key: policy_loss,
                        "timestep": total_timesteps,
                        "update_step": update_step
                    }, step=total_timesteps)
                    update_step += 1

            wandb.log({
                "episode": episode,
                "episode_reward": np.mean(episode_rewards),
                "episode_length": np.mean(episode_steps),
                "buffer_size": len(replay_buffer),
                "total_timesteps": total_timesteps
            }, step=total_timesteps)

            print(f"Episode {episode+1}: Avg Reward = {np.mean(episode_rewards):.2f}, Avg Steps = {np.mean(episode_steps):.2f}, Buffer size = {len(replay_buffer)}")
    finally:
        # Release the sub-environments even if training is interrupted.
        envs.close()

    torch.save(agent.policy.state_dict(), f"policy_{run_name}.pt")
    torch.save(agent.q1.state_dict(), f"q1_{run_name}.pt")
    torch.save(agent.q2.state_dict(), f"q2_{run_name}.pt")

    wandb.save(f"policy_{run_name}.pt")
    wandb.save(f"q1_{run_name}.pt")
    wandb.save(f"q2_{run_name}.pt")

    wandb.finish()

    return agent


if __name__ == "__main__":
    # Request the "spawn" start method for torch.multiprocessing when a GPU is
    # present (presumably because CUDA cannot be used from forked workers -
    # TODO confirm the intent).
    if torch.cuda.is_available():
        torch.multiprocessing.set_start_method('spawn')

    # Algorithm to run; one of "sac", "dfa" or "trpo_sb3".
    algorithm = "dfa"

    # Hyperparameters shared by every algorithm.
    gamma = 0.99
    env_name = "Pendulum-v1"  # alternative: "MountainCarContinuous-v0"

    if algorithm != "trpo_sb3":
        # SAC / DFA: both go through train_sac_dfa, differing only in the flag.
        use_dfa = (algorithm == "dfa")
        beta = 0.2
        alpha = 0.01
        tau = 0.005
        policy_lr = 1e-3

        agent = train_sac_dfa(
            dfa=use_dfa,
            beta=beta,
            alpha=alpha,
            gamma=gamma,
            tau=tau,
            env_name=env_name,
            policy_lr=policy_lr
        )
    else:
        # TRPO via sb3_contrib.
        learning_rate = 1e-3
        max_kl = 0.01
        batch_size = 256
        n_steps = 2048

        agent = train_trpo_sb3(
            env_name=env_name,
            gamma=gamma,
            learning_rate=learning_rate,
            batch_size=batch_size,
            n_steps=n_steps,
            max_kl=max_kl
        )

    # Final status lines.
    print("\nTraining Complete!")
    print(f"Model weights saved for {algorithm.upper()} algorithm")